import jax
import jax.numpy as jnp
import json
import argparse
import os, sys
# Prepend the sibling research checkouts to the import path. The locations are
# resolved against the current working directory (not this file's location),
# so they are only found when the script is launched from a directory whose
# grandparent contains them. Each insert goes to the front, so afterwards
# sys.path starts with sae-softmax, lsh, sae-jax, and ``root`` holds the last
# path inserted.
for root in [
    os.path.abspath(os.path.join(os.getcwd(), '..', '..', repo_name))
    for repo_name in ('sae-jax', 'lsh', 'sae-softmax')
]:
    sys.path.insert(0, root)

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from datasets import load_dataset
from tqdm import tqdm
import numpy as np
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

def _positive_int(value):
    """argparse ``type``: parse *value* as an int and require it to be >= 1."""
    number = int(value)
    if number < 1:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {value!r}")
    return number


def parse_args(argv=None):
    """Parse command-line options for the embedding sampling / plotting script.

    Args:
        argv: optional list of argument strings; ``None`` (the default) means
            ``sys.argv[1:]``, as with ``ArgumentParser.parse_args``.

    Returns:
        argparse.Namespace with the parsed options. Path options have ``~``
        expanded, including their string defaults (argparse runs ``type`` on
        string defaults too), so they can be passed straight to ``os`` calls.
    """
    parser = argparse.ArgumentParser(description='Sample last token embeddings')
    parser.add_argument('--model_name', type=str, default="google/gemma-7b",
                        help='Name of the base model to use')
    parser.add_argument('--dataset_name', type=str, default="bookcorpus",
                        help='Name of the dataset to use')
    parser.add_argument('--total_samples', type=_positive_int, default=1000,
                        help='Total number of samples to evaluate')
    # Must be >= 1: it is used as a modulus when deciding when to checkpoint.
    parser.add_argument('--save_every', type=_positive_int, default=50,
                        help='Save results every N samples')
    # os.path.expanduser as the type: without it, "~" stays literal and
    # os.makedirs creates a directory named "~" under the working directory.
    parser.add_argument('--cache_dir', type=os.path.expanduser, default="~/gemma_cache",
                        help='Directory to cache model files')
    parser.add_argument('--output_dir', type=os.path.expanduser,
                        default="~/results/embeddings",
                        help='Directory to save output files')
    parser.add_argument('--load_embeddings', action='store_true',
                        help='Load existing embeddings instead of generating new ones')
    parser.add_argument('--embeddings_path', type=os.path.expanduser,
                        help='Path to the embeddings numpy file')
    parser.add_argument('--plot_only', action='store_true',
                        help='Only perform plotting, skip embedding generation/loading')
    parser.add_argument('--plot_title', type=str, default='PCA Projection of Embeddings and Unembeddings',
                        help='Title for the plot')

    return parser.parse_args(argv)

def plot_embeddings(output_embeddings, last_token_embeddings, output_dir, plot_title):
    """Project both embedding sets to 2-D with PCA, then save a scatter plot and the projections.

    The PCA basis is fit on ``output_embeddings`` only. ``last_token_embeddings``
    are projected into that same basis so both point clouds share axes.

    Writes ``embeddings_pca_plot.png``, ``output_embeddings_pca.npy`` and
    ``last_token_embeddings_pca.npy`` into ``output_dir``, which must already exist.

    Args:
        output_embeddings: 2-D array, one row per output (unembedding) vector.
        last_token_embeddings: 2-D array, one row per sample. It must have the
            same number of columns as ``output_embeddings``.
        output_dir: directory that receives the plot and the projected arrays.
        plot_title: title drawn on the figure.

    Raises:
        ValueError: if either input is not 2-D, ``last_token_embeddings`` has
            no rows, or the two inputs have different widths. These checks run
            before the (potentially slow) PCA fit.
    """
    output_embeddings = np.asarray(output_embeddings)
    last_token_embeddings = np.asarray(last_token_embeddings)
    if output_embeddings.ndim != 2 or last_token_embeddings.ndim != 2:
        raise ValueError(
            f"expected 2-D arrays, got shapes {output_embeddings.shape} "
            f"and {last_token_embeddings.shape}"
        )
    if last_token_embeddings.shape[0] == 0:
        raise ValueError("last_token_embeddings is empty; nothing to plot")
    if output_embeddings.shape[1] != last_token_embeddings.shape[1]:
        raise ValueError(
            f"embedding widths differ: {output_embeddings.shape[1]} (output) "
            f"vs {last_token_embeddings.shape[1]} (last token)"
        )

    # Fit on the output embeddings, then project both sets into that basis.
    pca = PCA(n_components=2)
    pca.fit(output_embeddings)
    output_embeddings_2d = pca.transform(output_embeddings)
    last_token_embeddings_2d = pca.transform(last_token_embeddings)

    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        ax.scatter(output_embeddings_2d[:, 0], output_embeddings_2d[:, 1],
                   c='blue', alpha=0.5, label='Unembeddings', s=10)
        ax.scatter(last_token_embeddings_2d[:, 0], last_token_embeddings_2d[:, 1],
                   c='red', alpha=0.5, label='Embeddings', s=10)

        ax.set_title(plot_title, fontsize=20)
        ax.set_xlabel('First Principal Component', fontsize=16)
        ax.set_ylabel('Second Principal Component', fontsize=16)
        ax.legend(fontsize=16)
        ax.grid(True, alpha=0.3)

        fig.savefig(os.path.join(output_dir, 'embeddings_pca_plot.png'), dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if drawing or saving fails. Otherwise pyplot
        # keeps it alive for the rest of the process.
        plt.close(fig)

    # Save the PCA-transformed embeddings.
    np.save(os.path.join(output_dir, 'output_embeddings_pca.npy'), output_embeddings_2d)
    np.save(os.path.join(output_dir, 'last_token_embeddings_pca.npy'), last_token_embeddings_2d)

    print(f"Saved PCA plot and transformed embeddings to {output_dir}")

def _unembedding_matrix(model):
    """Return the model's output-embedding weight matrix as a NumPy array on the CPU."""
    # .float() lets half-precision weights convert; .numpy() rejects bfloat16.
    return model.get_output_embeddings().weight.detach().cpu().float().numpy()


def _collect_last_token_embeddings(model, tokenizer, dataset, device, total_samples,
                                   save_every, output_dir):
    """Return final hidden-state vectors of each prompt's last token, for up to ``total_samples`` rows.

    Each prompt is the first quarter (by characters) of a row's ``"text"`` field.
    The collected vectors are checkpointed to ``last_token_embeddings_partial.npy``
    in ``output_dir`` every ``save_every`` samples. A final checkpoint is written
    for any trailing remainder, including when the dataset runs out first.

    Raises:
        ValueError: if no usable prompt was found in ``dataset``.
    """
    checkpoint_path = os.path.join(output_dir, 'last_token_embeddings_partial.npy')
    collected = []

    def save_checkpoint():
        np.save(checkpoint_path, np.array(collected))
        print(f"Checkpoint saved at sample {len(collected)}")

    with tqdm(total=total_samples) as progress:
        # Iterate rows lazily. ``dataset["text"]`` would materialise the whole
        # column in memory even though only ``total_samples`` rows are used.
        for example in dataset:
            sentence = example["text"]
            prompt = sentence[:len(sentence) // 4]  # Take first quarter of sentence
            if not prompt.strip():
                # Sentences shorter than 4 characters give an empty prompt,
                # which has no text to embed.
                continue

            inputs = tokenizer(prompt, return_tensors="pt").to(device)
            with torch.no_grad():
                outputs = model(**inputs, output_hidden_states=True)

            last_hidden_state = outputs.hidden_states[-1]
            # .float() so half-precision activations can be converted to NumPy.
            collected.append(last_hidden_state[0, -1, :].float().cpu().numpy())
            progress.update(1)

            if len(collected) % save_every == 0:
                save_checkpoint()
            if len(collected) >= total_samples:
                break

    if collected and len(collected) % save_every != 0:
        save_checkpoint()
    if not collected:
        raise ValueError("No usable samples found in the dataset; nothing to plot.")
    return np.array(collected)


def main():
    """Entry point: obtain last-token embeddings, fetch the unembedding matrix, and plot both.

    With ``--plot_only`` or ``--load_embeddings``, the embeddings are read from
    ``--embeddings_path``. Otherwise they are generated by running the model
    over ``--dataset_name``. The model is loaded exactly once in either case.

    Raises:
        ValueError: if ``--plot_only``/``--load_embeddings`` is given without
            ``--embeddings_path``.
    """
    args = parse_args()

    # The default paths contain "~". Expanding here is idempotent, so it is
    # harmless if parse_args already expanded them.
    output_dir = os.path.expanduser(args.output_dir)
    cache_dir = os.path.expanduser(args.cache_dir)
    os.makedirs(output_dir, exist_ok=True)

    if args.plot_only or args.load_embeddings:
        # Both flags read precomputed embeddings. Only the error message names
        # the flag; --plot_only takes precedence when both are set.
        if not args.embeddings_path:
            flag = "--plot_only" if args.plot_only else "--load_embeddings"
            raise ValueError(f"--embeddings_path must be provided when using {flag}")
        last_token_embeddings = np.load(os.path.expanduser(args.embeddings_path))
        # Only the weights are read here, so the model stays on the CPU.
        gemma_model = AutoModelForCausalLM.from_pretrained(args.model_name, cache_dir=cache_dir)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        gemma_tokenizer = AutoTokenizer.from_pretrained(args.model_name, cache_dir=cache_dir)
        # This model is also reused for the unembedding matrix below, so no
        # second copy is loaded.
        gemma_model = AutoModelForCausalLM.from_pretrained(args.model_name, cache_dir=cache_dir).to(device)
        ds = load_dataset(args.dataset_name, split="train")
        last_token_embeddings = _collect_last_token_embeddings(
            gemma_model, gemma_tokenizer, ds, device,
            args.total_samples, args.save_every, output_dir,
        )

    output_embeddings = _unembedding_matrix(gemma_model)
    plot_embeddings(output_embeddings, last_token_embeddings, output_dir, args.plot_title)

if __name__ == "__main__":
    main()


